{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2a59b2aa-b879-4753-9505-e7e36a9e829b",
   "metadata": {},
   "source": [
    "## Pattern I: Using an LLM to detect intent and recognize/extract entities, followed by Text-to-SQL generation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdc0e4b0-8d1e-45a0-b73f-c45854fb91c8",
   "metadata": {},
   "source": [
    "#### Imports "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "926b288e-44b7-4217-a92b-579bb050f7a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.chat import SystemMessagePromptTemplate\n",
    "from langchain.prompts.chat import HumanMessagePromptTemplate\n",
    "from langchain.prompts.chat import AIMessagePromptTemplate\n",
    "from langchain.prompts.chat import ChatPromptTemplate\n",
    "from langchain.chat_models import ChatVertexAI\n",
    "from google.cloud import bigquery\n",
    "import pandas as pd\n",
    "import logging\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dfa361e-1e53-4a8d-a724-60b5a1b91c0b",
   "metadata": {},
   "source": [
    "##### Setup logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "695b22a4-b0be-4c32-9221-4b3c048ff0a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide logger; INFO level so each pipeline step's result is shown.\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.INFO)\n",
    "# Attach the stream handler only once: re-running this cell would otherwise\n",
    "# stack handlers and emit every log message multiple times.\n",
    "if not logger.handlers:\n",
    "    logger.addHandler(logging.StreamHandler())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab59b237-cd91-4724-ba21-43062d682164",
   "metadata": {},
   "source": [
    "#### Setup essentials "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "94d12698",
   "metadata": {},
   "outputs": [],
   "source": [
    "SERVICE_ACCOUNT_CREDENTIALS = './../credentials/vai-key.json'\n",
    "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_CREDENTIALS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38589da0-d8e0-4ae2-b787-ef6ccd87a9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT = 'arun-genai-bb'\n",
    "LOCATION = 'us-central1'\n",
    "MODEL_NAME = 'codechat-bison@latest'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d7079800-0e0b-4692-9fb9-d56f31aaffae",
   "metadata": {},
   "outputs": [],
   "source": [
    "bq = bigquery.Client()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d04c7f3e-14d2-423e-8559-9edc0d457a3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatVertexAI(project=PROJECT, \n",
    "                   location=LOCATION, \n",
    "                   model_name=MODEL_NAME, \n",
    "                   temperature=0.0, \n",
    "                   max_output_tokens=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cd1c24a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_query = \"Provide a list of all flight reservations from October 10th to October 15th, 2023\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52d65a6e-cb82-4640-987c-a87803303f95",
   "metadata": {},
   "source": [
    "### Step 1: Identify the `intent` of the user's query"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "609c1870-55b3-4752-8683-ef79c2fc582f",
   "metadata": {},
   "source": [
    "##### Load example prompt and completion pairs needed for intent detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "edd80864-902e-4bd6-98ef-fa3520c3625a",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3203ddec-7303-4a3d-88e8-b232251e58fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt</th>\n",
       "      <th>intent</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Need all the bookings from 10th to 15th Octobe...</td>\n",
       "      <td>RETRIEVE_RESERVATIONS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Could you retrieve reservations for mid-Octobe...</td>\n",
       "      <td>RETRIEVE_RESERVATIONS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Let’s see all the October reservations from 10...</td>\n",
       "      <td>RETRIEVE_RESERVATIONS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Any reservations from 10/10/2023 to 15/10/2023?</td>\n",
       "      <td>RETRIEVE_RESERVATIONS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>I’m looking for bookings between the second an...</td>\n",
       "      <td>RETRIEVE_RESERVATIONS</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              prompt                 intent\n",
       "0  Need all the bookings from 10th to 15th Octobe...  RETRIEVE_RESERVATIONS\n",
       "1  Could you retrieve reservations for mid-Octobe...  RETRIEVE_RESERVATIONS\n",
       "2  Let’s see all the October reservations from 10...  RETRIEVE_RESERVATIONS\n",
       "3    Any reservations from 10/10/2023 to 15/10/2023?  RETRIEVE_RESERVATIONS\n",
       "4  I’m looking for bookings between the second an...  RETRIEVE_RESERVATIONS"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "examples = pd.read_csv('./../data/few-shot/prompts_intent.csv')\n",
    "examples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7602708f-3523-4771-b2c4-a89de7d5e7f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"You are a helpful assistant capable of detecting the intent behind a user's query.\"\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "messages.append(system_message_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "21e08ee5-cec9-45cf-98a2-1f8a98e52649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add each (prompt, intent) example as a human/AI message pair.\n",
    "# The examples are literal text, not templates, so braces are escaped:\n",
    "# an example containing '{' or '}' would otherwise be parsed as a template\n",
    "# variable and make `format_prompt` fail.\n",
    "for _, row in examples.iterrows():\n",
    "    prompt, completion = row\n",
    "    prompt = prompt.replace('{', '{{').replace('}', '}}')\n",
    "    completion = completion.replace('{', '{{').replace('}', '}}')\n",
    "    human_message = HumanMessagePromptTemplate.from_template(prompt)\n",
    "    messages.append(human_message)\n",
    "    ai_message = AIMessagePromptTemplate.from_template(completion)\n",
    "    messages.append(ai_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6116d0f1-47e0-4282-a5c2-d0ae7317b209",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_template = \"{user_query}\"\n",
    "human_message = HumanMessagePromptTemplate.from_template(human_template)\n",
    "messages.append(human_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c019cd85-539a-4030-ba83-0db73701aa69",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "36b160c9-33f5-4978-85f8-1d049922dab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = chat_prompt.format_prompt(user_query=user_query).to_messages()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "06dd43b3-d705-4637-8721-44f23cfc3d75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RETRIEVE_RESERVATIONS\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 45.2 ms, sys: 6.65 ms, total: 51.8 ms\n",
      "Wall time: 4.46 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "response = llm(request)\n",
    "intent = response.content.strip()\n",
    "logger.info(intent)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bfea293-0f0a-40af-9905-f5c041c97c22",
   "metadata": {},
   "source": [
    "### Step 2: Extract the entities from the user query"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0c68a41-37ae-43d2-bcc2-5baa49f8fe64",
   "metadata": {},
   "source": [
    "##### Load example prompt and completion pairs needed for entity recognition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "882df9a4-0bfa-4520-8fad-96f094162421",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "87ef8f0e-c78c-41bc-956d-11011c1d3809",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt</th>\n",
       "      <th>entities</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Can you show me all the reservations from Octo...</td>\n",
       "      <td>Start Date:October 10th, 2023|End Date:October...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What bookings do we have from 10/10/2023 to 10...</td>\n",
       "      <td>Start Date:10/10/2023|End Date:10/15/2023</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Show the reservations occurring between the se...</td>\n",
       "      <td>Start Date:October 8th, 2023|End Date:October ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>List all bookings that are happening from Octo...</td>\n",
       "      <td>Start Date:October 10, 2023|End Date:October 1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Fetch the reservations from the second week of...</td>\n",
       "      <td>Start Date:October 8th, 2023|End Date:October ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              prompt  \\\n",
       "0  Can you show me all the reservations from Octo...   \n",
       "1  What bookings do we have from 10/10/2023 to 10...   \n",
       "2  Show the reservations occurring between the se...   \n",
       "3  List all bookings that are happening from Octo...   \n",
       "4  Fetch the reservations from the second week of...   \n",
       "\n",
       "                                            entities  \n",
       "0  Start Date:October 10th, 2023|End Date:October...  \n",
       "1          Start Date:10/10/2023|End Date:10/15/2023  \n",
       "2  Start Date:October 8th, 2023|End Date:October ...  \n",
       "3  Start Date:October 10, 2023|End Date:October 1...  \n",
       "4  Start Date:October 8th, 2023|End Date:October ...  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "examples = pd.read_csv('./../data/few-shot/prompts_ner.csv')\n",
    "examples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6d2ab3f8-b094-41ae-b1b8-0c73c7cf0745",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"You are a helpful assistant capable of performing named entity recognition.\"\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "messages.append(system_message_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "90a035f6-4a09-443b-8dd6-418bc75e89a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add each (prompt, entities) example as a human/AI message pair.\n",
    "# The examples are literal text, not templates, so braces are escaped:\n",
    "# an example containing '{' or '}' would otherwise be parsed as a template\n",
    "# variable and make `format_prompt` fail.\n",
    "for _, row in examples.iterrows():\n",
    "    prompt, completion = row\n",
    "    prompt = prompt.replace('{', '{{').replace('}', '}}')\n",
    "    completion = completion.replace('{', '{{').replace('}', '}}')\n",
    "    human_message = HumanMessagePromptTemplate.from_template(prompt)\n",
    "    messages.append(human_message)\n",
    "    ai_message = AIMessagePromptTemplate.from_template(completion)\n",
    "    messages.append(ai_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ee401a5b-73b0-4b51-b9b1-b4fb7ddd74b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_template = \"{user_query} Standardize the date format to YYYY-MM-DD.\"\n",
    "human_message = HumanMessagePromptTemplate.from_template(human_template)\n",
    "messages.append(human_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "025b88c5-ba1b-434d-a407-918930af6ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "128b1cde-ce54-4cf2-8a6b-ee7e8da50e32",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = chat_prompt.format_prompt(user_query=user_query).to_messages()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9cb84887-02d2-42da-848b-b04b40e22b0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Start Date:2023-10-10|End Date:2023-10-15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7.63 ms, sys: 3.49 ms, total: 11.1 ms\n",
      "Wall time: 3.39 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "response = llm(request)\n",
    "entities = response.content.strip()\n",
    "logger.info(entities)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6249778-0e8a-44c5-bcb1-dc3471258cfd",
   "metadata": {},
   "source": [
    "### Step 3: Map Intent to table names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0df9baa0-b708-4c3d-bf81-5feebbd212c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "0ae47bfc-b586-4b88-8d2a-554d208038b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>intent</th>\n",
       "      <th>mapped_tables</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RETRIEVE_RESERVATIONS</td>\n",
       "      <td>reservations|flights</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>IDENTIFY_RECENT_CUSTOMERS</td>\n",
       "      <td>reservations|customers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CALCULATE_REVENUE</td>\n",
       "      <td>reservations|transactions</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>FIND_PEAK_DEPARTURE_MONTHS</td>\n",
       "      <td>flights</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GROUP_AND_COUNT_CUSTOMERS_BY_AGE</td>\n",
       "      <td>customers</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             intent              mapped_tables\n",
       "0             RETRIEVE_RESERVATIONS       reservations|flights\n",
       "1         IDENTIFY_RECENT_CUSTOMERS     reservations|customers\n",
       "2                 CALCULATE_REVENUE  reservations|transactions\n",
       "3        FIND_PEAK_DEPARTURE_MONTHS                    flights\n",
       "4  GROUP_AND_COUNT_CUSTOMERS_BY_AGE                  customers"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "examples = pd.read_csv('./../data/few-shot/intent_to_table_mapping.csv')\n",
    "examples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "5a15fc4d-ec8d-4f55-8063-c388629fab9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"You are a helpful assistant capable of mapping detected intent to the correct list of BigQuery tables.\"\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "messages.append(system_message_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e360225d-7880-486d-a16e-07cf73b6a7a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add each (intent, mapped_tables) example as a human/AI message pair.\n",
    "# The examples are literal text, not templates, so braces are escaped:\n",
    "# an example containing '{' or '}' would otherwise be parsed as a template\n",
    "# variable and make `format_prompt` fail.\n",
    "for _, row in examples.iterrows():\n",
    "    prompt, completion = row\n",
    "    prompt = prompt.replace('{', '{{').replace('}', '}}')\n",
    "    completion = completion.replace('{', '{{').replace('}', '}}')\n",
    "    human_message = HumanMessagePromptTemplate.from_template(prompt)\n",
    "    messages.append(human_message)\n",
    "    ai_message = AIMessagePromptTemplate.from_template(completion)\n",
    "    messages.append(ai_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5466b959-ed90-4e15-8130-beeb503e8f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_template = \"{user_intent}\"\n",
    "human_message = HumanMessagePromptTemplate.from_template(human_template)\n",
    "messages.append(human_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "6fbf6e62-4909-42d1-a5ce-4539b90fc49f",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "790348fd-22d9-4491-94bc-ae7de28d1756",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = chat_prompt.format_prompt(user_intent=intent).to_messages()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "3e245c6a-f3e4-44c9-a416-e2fc03095755",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "reservations|flights\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.22 ms, sys: 2.26 ms, total: 6.48 ms\n",
      "Wall time: 709 ms\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "response = llm(request)\n",
    "tables = response.content.strip()\n",
    "logger.info(tables)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc050c78",
   "metadata": {},
   "source": [
    "### Step 4: Load and filter table schemas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "962f7ecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_files_from_dir(directory):\n",
    "    \"\"\"Read every regular file directly inside `directory` into a dict.\n",
    "\n",
    "    Each key is the file name with a trailing '.txt' extension removed, so\n",
    "    'flights.txt' is stored under 'flights'. Sub-directories are skipped.\n",
    "\n",
    "    Args:\n",
    "        directory: Path of the directory to read.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping file name (minus a '.txt' suffix) to the file's UTF-8\n",
    "        text; an empty dict when `directory` is not an existing directory.\n",
    "    \"\"\"\n",
    "    # isdir (rather than exists) also rejects a path that points at a file,\n",
    "    # which would make os.listdir raise below.\n",
    "    if not os.path.isdir(directory):\n",
    "        # logger.warn is a deprecated alias of logger.warning.\n",
    "        logger.warning(f\"The directory {directory} does not exist!\")\n",
    "        return {}\n",
    "\n",
    "    files_dict = {}\n",
    "\n",
    "    # Sorted so the result's insertion order does not depend on the OS.\n",
    "    for filename in sorted(os.listdir(directory)):\n",
    "        file_path = os.path.join(directory, filename)\n",
    "\n",
    "        # Ensure it's a file and not a sub-directory or other entity\n",
    "        if not os.path.isfile(file_path):\n",
    "            continue\n",
    "\n",
    "        with open(file_path, 'r', encoding='utf-8') as file:\n",
    "            content = file.read()\n",
    "\n",
    "        # Strip only a trailing '.txt'; split('.txt')[0] would also cut the\n",
    "        # name at a '.txt' occurring earlier in it.\n",
    "        key = filename[:-len('.txt')] if filename.endswith('.txt') else filename\n",
    "        files_dict[key] = content\n",
    "\n",
    "    return files_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "99b60a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "directory_path = './../data/text-schema/'\n",
    "table_schemas = read_files_from_dir(directory_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "aee2e375",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `tables` is the LLM's '|'-separated list of table names. Tolerate stray\n",
    "# whitespace around each name and ignore empty segments.\n",
    "table_names = [name.strip() for name in tables.split('|') if name.strip()]\n",
    "filtered_table_schemas = {}\n",
    "\n",
    "for table_name in table_names:\n",
    "    if table_name in table_schemas:\n",
    "        filtered_table_schemas[table_name] = table_schemas[table_name]\n",
    "    else:\n",
    "        # Surface unmapped tables instead of silently dropping them, since\n",
    "        # the SQL prompt would then be built without their schema.\n",
    "        logger.warning(f\"No schema found for table '{table_name}'; skipping.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "b1545fe5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "----\n",
      "Reservations Table:\n",
      "Description:\n",
      "The Reservations table keeps track of all flight reservations made by customers. Each record represents a unique reservation, detailing the customer, flight, reservation time, and status.\n",
      "----\n",
      "Columns:\n",
      "--\n",
      "reservation_id:\n",
      "Description: A unique identifier for each reservation made on the platform.\n",
      "Usage: This ID ensures that each reservation is distinct and can be referenced for customer inquiries, modifications, and operational tracking.\n",
      "Type: INT64\n",
      "--\n",
      "customer_id:\n",
      "Description: A reference to a customer from the Customers table who made the reservation.\n",
      "Usage: Establishes which customer made a specific reservation, aiding in personalized user experiences, communication, and support.\n",
      "Type: INT64\n",
      "--\n",
      "flight_id:\n",
      "Description: Refers to a specific flight from the Flights table.\n",
      "Usage: Ensures that the reservation corresponds to a specific flight, aiding in managing flight capacities and customer communications.\n",
      "Type: INT64\n",
      "--\n",
      "reservation_datetime:\n",
      "Description: Indicates the date and time when the reservation was made.\n",
      "Usage: Helps in tracking reservation timelines for data analysis and customer communication, and can be used to manage reservation expiries or follow-ups.\n",
      "Type: DATETIME\n",
      "--\n",
      "status:\n",
      "Description: The current status of the reservation (e.g., Confirmed, Cancelled, Pending, etc.)\n",
      "Usage: Provides customers and operational teams with crucial information regarding the state of the reservation, and assists in managing operational processes and customer communications.\n",
      "Type: STRING\n",
      "------\n",
      "Flights Table:\n",
      "Description:\n",
      "The Flights table records all available flights and their relevant details. Each record represents a unique flight, including details about its origin, destination, departure, arrival, carrier, and price.\n",
      "----\n",
      "Columns:\n",
      "--\n",
      "flight_id:\n",
      "Description: A unique identifier for each flight.\n",
      "Usage: This ID ensures each flight is distinct and can be referenced for booking, customer inquiries, and operational purposes.\n",
      "Type: INT64\n",
      "--\n",
      "origin:\n",
      "Description: The departure location of the flight.\n",
      "Usage: Helps customers and operational teams identify where the flight begins and assists in planning and managing flight schedules.\n",
      "Type: STRING\n",
      "--\n",
      "destination:\n",
      "Description: The arrival location of the flight.\n",
      "Usage: Aids customers in finding flights that reach their desired destination and helps operational teams manage and plan flight schedules.\n",
      "Type: STRING\n",
      "--\n",
      "departure_datetime:\n",
      "Description: The date and time when the flight is scheduled to depart.\n",
      "Usage: Provides customers with the information needed to plan their journey and helps operational teams manage and track flights.\n",
      "Type: DATETIME\n",
      "--\n",
      "arrival_datetime:\n",
      "Description: The date and time when the flight is scheduled to arrive at the destination.\n",
      "Usage: Facilitates customers in planning their schedules around their arrival and assists operational teams in managing gate assignments and staff scheduling.\n",
      "Type: DATETIME\n",
      "--\n",
      "carrier:\n",
      "Description: The airline carrier operating the flight.\n",
      "Usage: Allows customers to select flights based on their preferred carrier and helps manage partnerships and operational logistics related to specific carriers.\n",
      "Type: STRING\n",
      "--\n",
      "price:\n",
      "Description: The price of a ticket on the flight.\n",
      "Usage: Enables customers to evaluate and purchase flights based on their budget and assists in revenue management and pricing strategies for the carrier.\n",
      "Type: FLOAT64\n",
      "--\n"
     ]
    }
   ],
   "source": [
    "# Concatenate the selected schemas into one prompt-ready string.\n",
    "# Join with a newline: with '' the closing '--' of one schema fuses with the\n",
    "# opening '----' of the next (giving '------') when a schema file has no\n",
    "# trailing newline.\n",
    "filtered_table_schemas_text = '\\n'.join(filtered_table_schemas.values())\n",
    "logger.info(filtered_table_schemas_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a47dd88-7d05-400f-8bb7-ca434ab2fe4f",
   "metadata": {},
   "source": [
    "### Step 5: Text-to-SQL generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "480c2ac3-f0fa-4d67-be8b-ea7fd30c4bc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "57f0f5d0-c600-4250-924a-c7b5e6e0f5c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"You are a SQL master expert capable of writing complex SQL query in BigQuery.\"\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "messages.append(system_message_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "e0151e93-311a-452e-9824-1d70f333bc39",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_template = \"\"\"Please construct a SQL query using the information provided below:\n",
    "\n",
    "Input Parameters:\n",
    "-----------------\n",
    "INTENT: {intent}\n",
    "EXTRACTED_ENTITIES: {entities}\n",
    "MAPPED_TABLES: {tables}\n",
    "\n",
    "User Query:\n",
    "-----------\n",
    "{user_query}\n",
    "\n",
    "Table Schemas:\n",
    "--------------\n",
    "{filtered_table_schemas_text}\n",
    "\n",
    "Note: \n",
    "- Please prefix the table names with `flight_reservations`.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "669fd020-d308-4f72-b8ec-eb9d0e957254",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_message = HumanMessagePromptTemplate.from_template(human_template)\n",
    "messages.append(human_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b51471c9-3fe7-4b6e-843f-6a90c771b02d",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "b642faa4-844a-4018-a903-70fd9a4b9bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = chat_prompt.format_prompt(intent=intent, \n",
    "                                    entities=entities, \n",
    "                                    tables=tables, \n",
    "                                    user_query=user_query, \n",
    "                                    filtered_table_schemas_text=filtered_table_schemas_text).to_messages()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "43f6e85b-999b-419d-89e2-e709d557e618",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SELECT * \n",
      "FROM flight_reservations.reservations \n",
      "WHERE reservation_datetime BETWEEN '2023-10-10' AND '2023-10-15'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.39 ms, sys: 2.8 ms, total: 7.19 ms\n",
      "Wall time: 1.67 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "response = llm(request)\n",
    "# The model usually wraps the query in a Markdown code fence (```sql ... ```).\n",
    "# Drop the fence lines only when they are actually present, so an unfenced\n",
    "# reply does not lose its first and last SQL lines.\n",
    "sql_lines = response.content.strip().split('\\n')\n",
    "if sql_lines and sql_lines[0].startswith('```'):\n",
    "    sql_lines = sql_lines[1:]\n",
    "if sql_lines and sql_lines[-1].strip() == '```':\n",
    "    sql_lines = sql_lines[:-1]\n",
    "sql = '\\n'.join(sql_lines)\n",
    "logger.info(sql)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92250742-c3f8-4d3f-b45e-f60a15fe861c",
   "metadata": {},
   "source": [
    "### Step 6: Execute the generated SQL query in BigQuery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "df6ed74e-d00b-4799-b759-394501988a37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>reservation_id</th>\n",
       "      <th>customer_id</th>\n",
       "      <th>flight_id</th>\n",
       "      <th>reservation_datetime</th>\n",
       "      <th>status</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6</td>\n",
       "      <td>6</td>\n",
       "      <td>6</td>\n",
       "      <td>2023-10-10 10:00:00</td>\n",
       "      <td>Confirmed</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>7</td>\n",
       "      <td>2023-10-12 11:30:00</td>\n",
       "      <td>Confirmed</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   reservation_id  customer_id  flight_id reservation_datetime     status\n",
       "0               6            6          6  2023-10-10 10:00:00  Confirmed\n",
       "1               7            6          7  2023-10-12 11:30:00  Confirmed"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = bq.query(sql).to_dataframe()\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "62ebdb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the query result as a Markdown table for the Step 7 prompt.\n",
    "# Guarded so that re-running this cell after `df` has already been converted\n",
    "# to a string is a no-op instead of an AttributeError.\n",
    "if isinstance(df, pd.DataFrame):\n",
    "    df = df.to_markdown(index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23492856-4ec6-44ec-84ac-543d844bb5a6",
   "metadata": {},
   "source": [
    "### Step 7: Transform SQL results into a human friendly response (Optional)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e412c470-ec03-4768-b41d-63d80d95376a",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "8750e86b-3157-42b2-a6ce-2841aabde957",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"You are a travel assistant chatbot that can help people make flight reservations.\"\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "messages.append(system_message_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b647788e-5876-4731-95da-08afec9b0d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_template = \"\"\"User's Question:\n",
    "----------------\n",
    "{user_query}\n",
    "\n",
    "BigQuery Result:\n",
    "----------------\n",
    "{bq_response}\n",
    "\n",
    "Task:\n",
    "-----\n",
    "Please convert the above query result into a human-readable format.\n",
    "\n",
    "IMPORTANT Notes:\n",
    "----------------\n",
    "- The response should be courteous and human-friendly.\n",
    "- If the answer doesn't require a tabular structure, avoid using it.\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "6c00a1d1-1afd-4155-9b6a-0f0f902041e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_message = HumanMessagePromptTemplate.from_template(human_template)\n",
    "messages.append(human_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "ecc333d7-232f-4311-b4e0-de1b0d316a67",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "7f11fe13-110d-451b-a4cf-a00fd2599761",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = chat_prompt.format_prompt(user_query=user_query,\n",
    "                                    bq_response=df).to_messages()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "07361137-d44f-4960-ad93-770e918db663",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Here is a list of all flight reservations from October 10th to October 15th, 2023:\n",
      "\n",
      "* Reservation ID: 6, Customer ID: 6, Flight ID: 6, Reservation Datetime: 2023-10-10 10:00:00, Status: Confirmed\n",
      "* Reservation ID: 7, Customer ID: 6, Flight ID: 7, Reservation Datetime: 2023-10-12 11:30:00, Status: Confirmed\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.45 ms, sys: 3.95 ms, total: 9.41 ms\n",
      "Wall time: 5.97 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "\n",
    "response = llm(request)\n",
    "output = response.content.strip()\n",
    "logger.info(output)"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "common-cpu.m110",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cpu:m110"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
